自定义回调函数和学习率调度器

Note

回调函数和学习率调度器也是可以自定义的,本节我们举例说明。

回调函数

from tensorflow import keras


class PrintValTrainRatioCallback(keras.callbacks.Callback):
    # 在每个epoch后打印val loss和train loss的比值
    # 若想batch-wise操作,自定义on_batch_end()即可
    def on_epoch_end(self, epoch, logs):
        print("\nval/train: {:.2f}".format(logs["val_loss"] / logs["loss"]))

学习率调度器

def exponential_decay(lr0, s):
    # 指数学习率
    def exponential_decay_fn(epoch):
        return lr0 * 0.1 ** (epoch / s)
    return exponential_decay_fn

# 第一步:实现一个以epoch为参数的学习率函数
exponential_decay_fn = exponential_decay(lr0=0.01, s=20)
# 第二步:将函数传给keras.callbacks.LearningRateScheduler
# 即可像其他回调函数那样在fit时使用
lr_scheduler = keras.callbacks.LearningRateScheduler(exponential_decay_fn)